iT邦幫忙

2023 iThome 鐵人賽

DAY 9
1
AI & Data

生成式AI到底何方神聖?一窺生程式AI的真面目系列 第 9

[Day9]:生成式AI如何開發-建立GAN有甚麼SOP

  • 分享至 

  • xImage
  •  

前言

前幾天帶各位快速複習了深度學習並且介紹了許多生成模型,相信各位在瞭解了生成模型以後,也會想要動手建立自己的生成模型。接下來我想先分享自己開發生成模型的經驗,希望能夠讓各位未來在開發生成模型時能多少有一些方向,當然這些只是經驗談,無法適用於所有人也無法適用於所有模型。但希望至少可以提供給各位一些靈感。

第一步:決定任務類型與要使用的模型

基本上這一步只是要確定你的任務類型,看是想做圖像生成還是圖像修復等任務,不過這點基本上就是有想法就可以開始動手做了。

生成模型的選擇還是要考慮電腦的效能,接著才考慮要使用什麼模型,不過像我在做實驗的通常都會選擇多種模型、比較先進的模型等。另外在選擇模型時也要注意該模型的用途,以及該模型若有指定的資料輸入/輸出格式都要注意。不過通常都可以再根據自己的需求改變模型的架構等。若是不知道能用什麼模型、完全沒有頭緒也可以問問無敵的Chat-GPT,只需要輸入你的任務類型以及想用的模型類別請他幫你推薦生成模型的種類就好,不過有時候它也會唬爛不存在的模型,在詢問過後也請一定要去查證
https://ithelp.ithome.com.tw/upload/images/20230912/20151029DU45v91bJL.png

最後就是資料集的選擇,一定要確定資料集是符合需求的,而且資料集內容沒有遺漏、缺失等,以及資料集的數量是否足夠拿來訓練,若拿100張圖片之類的訓練生成模型的話基本上是訓練不出好的成果的!

訓練效果不好有時候有很大的原因是出在資料集上,如果有資料集的說明文件請務必詳讀,並且注意使用別人的資料集要小心著作權等的問題。通常研究用或練習非商業用途應該都還好,不過還是請各位注意有沒有一些版權相關說明~

第二步:匯入函式庫

這一步很簡單,根據需求匯入函式庫就好了。要注意的是各個函式庫的版本之間有沒有相容,若不相容通常都會在程式執行的時候噴出匪夷所思的bug,所以有時候噴出很奇怪的錯誤可以直接把錯誤訊息給複製貼上搜尋看看,通常都會有建議要你改變版本。另外就是使用GPU加速的也要注意CUDA版本跟TensorFlow有沒有相容!

另外程式有時會出現Warning訊息,這是當程式有些問題但不影響運行的時候會跳出的訊息,通常是函式庫作者等,在不正確使用程式,或在程式執行當下要提醒注意事項時會跳出。這個不太需要在意,只要程式能跑就好了XD能跑後再來考慮如何優化程式碼。
https://ithelp.ithome.com.tw/upload/images/20230912/20151029OmFoIgoeyd.png

通常會匯入的不外乎是這些函式庫,以及額外因應需求用到的其他函式庫。

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import os #檔案處理時使用的
from tensorflow.keras.preprocessing.image import ImageDataGenerator #資料增強用的
from tensorflow.keras.layers import ... #通常我的網路層都會直接從這邊import, 因為很長故省略
from tensorflow.keras.models import Model #建立模型主體的部分
from tensorflow.keras.optimizers import Adam #優化器,這邊可根據需求使用適合的優化器

第三步:資料前處理 (Data Preprocessing)

接著就是看似不起眼但能夠對訓練造成最直接影響的資料前處理,通常可以使用Numpy, Matplotlib, OpenCV等函式庫來幫忙確認資料集的可用性與做資料前處理、視覺化等。在這個階段有一些要注意的事情:

  • 資料是否有經過正規化,像素值通常會正規化使其值域落在0~1或者-1~1,這兩個值域分別是要對應激活函數Sigmoid與Tanh的。

  • 資料的shape是多少,這在設計模型的輸入層時會設定到,另外也可以留意shape在模型演算過程中的變化,有時候因為模型有計算錯誤等原因會導致shape不匹配而讓程式出錯誤。所以shape的控制要特別注意。通常訓練的shape格式都是(資料數目, 圖片寬, 圖片高, 色彩通道數)

  • 圖片的內容是否有瑕疵,建議使用Matplotlib等函式庫將圖片顯示出來,若是圖片資料的話則可以直接開啟圖片檔案檢查。

  • 若是訓練圖像轉圖像之類的模型的話,要注意有些模型規定資料須成對,此時要確認資料是否有成對。

  • 如果資料數量太少的話,可以使用OpenCV, Numpy等進行資料增強,甚至Keras也有資料增強的方法可以使用。ImageDataGenerator是可以很方便地幫助你進行圖像資料增強的方法!

  • 訓練資料的資料格式通常會建議是Numpy Array格式,或者TensorFlow的Tensor格式,這兩種格式是輸入至神經網路需求的格式,然後陣列中元素的資料型態也要注意是不是為浮點數。

  • 如果是跟條件輸入有關的話,也要確定一下條件變量跟圖片是否正確成對。

    了解訓練資料集是訓練模型非常重要的要點!

第四步:建立生成模型類別

生成模型通常我會使用類別來建立,並且在初始化時進行一些超參數設定,這個用意是提高程式的可讀性,並且可以一目了然程式碼的超參數設定。以生成對抗網路 (GAN)為例通常要建立的就是生成模型判別模型以及對抗模型 (就是生成模型跟判別模型接在一起)、以及訓練方法跟一些其他方法例如訓練結束後儲存該次實驗的資料、訓練過程中儲存生成結果以便可以觀察到訓練過程的變化、自定義目標函數等等。

通常我在建立GAN時的程式會根據我整理的模板再更改,新增方法或者修改內容,使用模板可以加快程式的開發速度以及將開發邏輯保持一致,之後看程式碼才可以比較快理解。

class GAN:
    def __init__(self):
        self.genModel = self.build_generator()
        self.disModel = self.build_discriminator()
        self.advModel = self.build_adversarialmodel()
        if not os.path.isdir('./result'): #將訓練過程產生的圖片儲存起來
            os.mkdir('./result') #如果忘記新增資料夾可以用這個方式建立
    def build_generator(self):
        #建立生成器
        pass

    def build_discriminator(self):
        #建立判別器
        pass

    def build_adversarialmodel(self):
        #建立對抗模型
        pass
    
    def train(self, epochs):
        #建立訓練的方式
        pass

    def predict(self, num_images):
        #num_images是指一次生成多少張圖片
        #訓練結束後評估生成器性能
        pass
    
    def save_training_process(self):
        #將Loss變化,模型權重檔案等儲存起來
        pass

if __name__ == '__main__':
    gan = GAN()
    gan.train(epochs=100)
    gan.predict()

以上是我將模板簡化過後的程式,具體程式碼內會做的事情以及類別方法要傳遞的參數會在未來帶各位實作時更加詳細的說明。
如果你是對Python較熟悉、或者有寫過大型專案的話,可以使用抽象基底類別 (Abstract Base Class, ABC)來建立一個基礎的類別,並在之後的模型都繼承這個基礎類別。這麼做的好處是可以統一模型的建立方法,而且易於維護、可讀性高。當你不小心漏掉某些方法未定義時也會噴錯誤,提醒你有一些必要的方法還尚未被定義。~

第五步:建立模型、定義訓練方法

這一步我首先會建立生成、判別與對抗模型,也就是上面模板中的build_generatorbuild_discriminatorbuild_adversarialmodel,首先確定架構以後模型基本上建立步驟都類似:

  1. 建立輸入層,此時要注意輸入的shape是否符合預期。
  2. 建立中間隱藏層,可根據教學、文獻、自己的想法建立。如果是定義對抗網路的話則將生成器與判別器用Functional API接起來就好了。
  3. 建立輸出層,輸出層的激活函數要注意,生成器的激活函數根據正規化的方式設定,判別器通常都是使用Sigmoid。
  4. 將網路層用Model方法定義成完整模型。
  5. 定義模型優化器、目標函數 (損失函數)。
  6. 編譯模型,模型編譯完以後則可以確認模型的細節資訊等。

定義好模型後接著定義訓練方式,也就是train()裡面的內容,訓練方式不外乎就是:

  1. 將訓練資料準備好。通常訓練資料如果比較難處理我會習慣開新的檔案處理完再從處理資料用的檔案import進訓練用的檔案。
  2. 訓練判別器,並且儲存損失。
  3. 訓練生成器,並且儲存損失。

    此時會使用對抗模型訓練生成器,在對抗模型中會固定判別器的權重,使判別器的部分變得不可訓練,這樣就可以只訓練生成器訓練生出可以騙過判別器的內容。

  4. 在經過特定的訓練次數會用生成器生成結果並儲存,用於訓練後分析訓練情形用的。
  5. 訓練完成後儲存訓練結果,包括模型權重檔案、損失圖、訓練過程產生的照片等。

第六步:開始訓練

上面模板中可以看到程式中最底下有一段程式碼:

if __name__ == '__main__':
    gan = GAN()
    gan.train(epochs=100)
    gan.predict()

這部分是主程式的程式區段,在這邊就可以很清楚看到生成模型架構的設定,以及超參數等部分,也可以使用類別內建的方法在訓練完以後生成一次圖片,以審視訓練的最終成果。這時候也要開始除Bug,有時候Bug也會因為設備不同、使用的IDE不同等原因出現錯誤,但都很好解決。最麻煩的是模型演算時 shape出錯 ,這就要多花時間去看看程式碼再出錯的地方shape的變化是如何,除此之外也有許多大大小小的錯誤,就只能依靠耐心一步一步處理了。

第七步:審視訓練結果並改進

此時如果程式碼沒有錯誤,那恭喜你,需要除錯的已經除完了。接下來等著是更大的折磨😭,調整超參數,通常生成模型不可能一步就到位,除非是使用別人寫好沒問題的模型。

在此時通常會發現訓練不穩定、生成的圖片崩壞、梯度消失、模式崩潰等一堆問題,這些問題並不會噴錯誤,而是要自己根據訓練出來的結果評估,遇到這些問題該怎麼改進。那此時也不用擔心,在GitHub上或者一些論壇也可以常常看到大神們的教學,可以根據建議去修改模型,當超參數調整好了以後,訓練出來的模型足夠優秀時,真的會感覺到滿滿的成就感!
https://ithelp.ithome.com.tw/upload/images/20230912/20151029jIMaY6wV5O.png

結語

今天分享了我在做生成模型的一些習慣,不過撰寫程式的邏輯還是以各位的習慣為準,GAN模型的寫法真的非常多種。未來撰寫模型時基本上也會根據今天介紹的模板以這個大方向去寫,接下來就會開始進入實作了,在這幾天會會帶領各位一步一步搭建生成對抗網路模型,並介紹不同的GAN模型與差異。


上一篇
[Day8]:生成對抗網路 (GAN)原理介紹
下一篇
[Day10]:我的第一個GAN模型
系列文
生成式AI到底何方神聖?一窺生程式AI的真面目31
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言